from collections import Counter, defaultdict, namedtuple
from math import log2

import pydot
from sklearn import tree

def split_dataset(dataset, classes, feat_idx):
    '''
    Partition a dataset by the value each sample has for one feature.

    dataset: sequence of feature vectors to split; each vector must support
        indexing, slicing and ``+`` concatenation (e.g. list or tuple)
    classes: labels paired positionally with the vectors in ``dataset``
    feat_idx: index of the splitting feature inside each feature vector;
        negative indices count from the end, as in normal indexing

    Returns a dict mapping every observed feature value to a two-item list
    ``[sub_dataset, sub_classes]``, where ``sub_dataset`` holds the matching
    vectors with the splitting feature removed and ``sub_classes`` holds
    their labels in the same order.

    Raises IndexError if ``feat_idx`` is out of range for a vector.
    '''
    splited_dict = {}
    for data_vect, cls in zip(dataset, classes):
        feat_val = data_vect[feat_idx]
        # Convert a negative index to its non-negative equivalent. Without
        # this, ``feat_idx + 1`` can reach 0 (for -1) and the tail slice
        # would re-append the whole vector instead of skipping one element.
        pos = feat_idx if feat_idx >= 0 else feat_idx + len(data_vect)
        reduced_vect = data_vect[:pos] + data_vect[pos + 1:]
        sub_dataset, sub_classes = splited_dict.setdefault(feat_val, [[], []])
        sub_dataset.append(reduced_vect)
        sub_classes.append(cls)
    return splited_dict

def get_majority(classes):
    '''
    Return the label that occurs most often in ``classes``.

    When several labels share the highest count, the one that appeared
    first in ``classes`` is returned. An empty input raises ValueError.
    '''
    tally = defaultdict(int)
    for label in classes:
        tally[label] += 1
    # Iterating the dict yields labels in first-seen order, and max() keeps
    # the first maximal item, so ties resolve to the earliest label.
    return max(tally.keys(), key=lambda label: tally[label])

def get_shanno_entropy(values):
    '''
    Compute the Shannon entropy (in bits, i.e. log base 2) of ``values``.

    values: any iterable of hashable items, e.g. a list of class labels.
        One-shot iterables such as generators are accepted because the
        input is traversed exactly once.

    Returns the sum over distinct items of ``-p * log2(p)``, where ``p`` is
    the fraction of entries equal to that item. An empty input yields 0.
    '''
    # Count every distinct item in a single pass instead of rescanning the
    # input once per distinct item.
    counts = Counter(values)
    total = sum(counts.values())
    # With no items the generator below is empty, so no division by zero
    # occurs and the result is 0.
    probs = (count / total for count in counts.values())
    return sum(-prob * log2(prob) for prob in probs)

def choose_best_split_feature(dataset, classes):
    '''
    '''
    base_entropy = get_shanno_entropy(classes)
    feat_num = len(dataset[0])
    